#!/usr/bin/env python3
# A5 Surface Neutrality engine (dual detectors) — stdlib only
# - Paired RNG for OFF/ON base noise
# - Detector A: smoothed slope-jump with parabolic sub-tick
# - Detector B: piecewise-linear break fit (least-squares) with sub-tick refinement

import argparse, csv, json, math, os, random, sys, time
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory `p` together with any missing parents; an existing directory is fine."""
    p.mkdir(exist_ok=True, parents=True)
def write_json(p: Path, obj):
    """Serialize `obj` as 2-space-indented JSON into file `p` (UTF-8), creating parent dirs first."""
    ensure_dir(p.parent)
    payload = json.dumps(obj, indent=2)
    p.write_text(payload, encoding='utf-8')
def write_csv(p: Path, header, rows):
    """Write a CSV file at `p` (UTF-8): one `header` row, then every row in `rows`.

    Parent directories are created on demand.
    """
    ensure_dir(p.parent)
    with p.open('w', newline='', encoding='utf-8') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)
def sha256_of_file(p: Path):
    """Return the hex SHA-256 digest of the file at `p`, streamed in 1 MiB chunks."""
    import hashlib
    digest = hashlib.sha256()
    with p.open('rb') as fh:
        while True:
            block = fh.read(1 << 20)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def sha256_of_text(s: str):
    """Return the hex SHA-256 digest of string `s` after UTF-8 encoding."""
    import hashlib
    payload = s.encode('utf-8')
    return hashlib.sha256(payload).hexdigest()
def load_json(p: Path):
    """Parse and return the JSON document stored (UTF-8) at `p`.

    Raises:
        FileNotFoundError: when `p` does not exist; the message names the path.
    """
    if not p.exists():
        raise FileNotFoundError(f"Missing file: {p}")
    text = p.read_text(encoding='utf-8')
    return json.loads(text)

# ---------- curve synthesis ----------
def gen_base_noise(D, sigma, rng):
    """Draw `D` samples from N(0, sigma^2) using `rng.gauss`.

    When `sigma` is at most 1e-15, a list of `D` zeros is returned and
    `rng` is never touched.
    """
    if sigma <= 1e-15:
        # Noiseless request: leave the RNG state unchanged.
        return [0.0] * D
    draw = rng.gauss
    return [draw(0.0, sigma) for _ in range(D)]

def synth_curve(D, tstar_true, slope_post, baseline, base_noise):
    """Synthesize a flat-then-linear curve sampled at ticks 1..D.

    The noiseless level is `baseline` for ticks d < tstar_true and
    baseline + slope_post*(d - tstar_true + 1) from tstar_true on, so the
    first ramp sample already sits one slope step above baseline.
    base_noise[d-1] is added at tick d (base_noise must hold at least D
    values).

    Whenever a sample would fall below its predecessor (the first sample is
    compared with `baseline`), it is replaced by the midpoint of the two.
    This halves downward steps but does not guarantee a nondecreasing curve.
    """
    curve = []
    last = baseline
    for tick in range(1, D + 1):
        if tick < tstar_true:
            level = baseline
        else:
            level = baseline + slope_post*(tick - tstar_true + 1)
        sample = level + base_noise[tick - 1]
        if sample < last:
            sample = (sample + last)/2.0  # soften dips toward the previous value
        curve.append(sample)
        last = sample
    return curve

# ---------- Detector A: slope-jump with sub-tick ----------
def detect_tstar_slope_jump(y, win):
    """Estimate the break tick t* as the point of largest slope increase.

    For each candidate sample index i, the mean slope over the `win` steps
    after i is compared with the mean slope over the `win` steps before it.
    The first index with the largest (forward - backward) jump is chosen and
    refined to sub-tick precision. The refinement takes the vertex of the
    parabola through the jump values at i-1, i and i+1, with the offset
    clamped to +/-0.5.

    Args:
        y: sequence of curve samples, y[k] being the sample at tick k+1.
        win: slope-averaging window length in samples.

    Returns:
        Estimated t* as a float on the 1-based tick axis. Curves with fewer
        than 4*win + 5 samples are not searched; max(2.0, len(y)/2.0) is
        returned instead.

    Raises:
        ValueError: if the curve is long enough to search but win < 1.
    """
    n = len(y)
    if n < 4*win + 5:
        return max(2.0, n/2.0)
    if win < 1:
        raise ValueError(f"slope window must be >= 1, got {win!r}")

    def jump(i):
        # Mean of dy[i:i+win] minus mean of dy[i-win:i], where dy[k] = y[k+1]-y[k].
        # Sums of consecutive differences telescope, so each window mean is a
        # single difference of y (O(1) per index instead of O(win)).
        back = (y[i] - y[i - win]) / float(win)
        fwd = (y[i + win] - y[i]) / float(win)
        return fwd - back

    # Candidates keep a 2*win+1 margin at both ends. That margin also keeps
    # jump(i-1) and jump(i+1) in bounds, so the refinement always uses real
    # neighbour values rather than placeholders.
    best_i = max(range(2*win + 1, n - 2*win - 1), key=jump)  # first maximum wins ties

    z1, z2, z3 = jump(best_i - 1), jump(best_i), jump(best_i + 1)
    curvature = z1 - 2.0*z2 + z3
    if curvature >= -1e-12:
        # Flat or convex around the peak: no interior maximum to interpolate.
        return float(best_i + 1)
    # Vertex of the parabola through (-1, z1), (0, z2), (+1, z3).
    frac = 0.5*(z1 - z3)/curvature
    frac = max(-0.5, min(0.5, frac))
    return float(best_i + 1) + frac

# ---------- Detector B: piecewise-linear break fit ----------
def _prefixes(y):
    D = len(y)
    pref1 = [0]*(D+1)             # count
    pref_i = [0]*(D+1)            # sum i
    pref_i2 = [0]*(D+1)           # sum i^2
    pref_y = [0.0]*(D+1)          # sum y
    pref_yi= [0.0]*(D+1)          # sum y*i
    pref_y2= [0.0]*(D+1)          # sum y^2
    for i in range(1, D+1):
        yi = y[i-1]
        pref1[i]  = pref1[i-1] + 1
        pref_i[i] = pref_i[i-1] + i
        pref_i2[i]= pref_i2[i-1] + i*i
        pref_y[i] = pref_y[i-1] + yi
        pref_yi[i]= pref_yi[i-1] + yi*i
        pref_y2[i]= pref_y2[i-1] + yi*yi
    return pref1,pref_i,pref_i2,pref_y,pref_yi,pref_y2

def _side_sums(pref, lo, hi):
    # inclusive [1..hi] - [1..lo-1]
    return pref[hi] - pref[lo-1]

def _sse_for_lambda(y, lam, pref1,pref_i,pref_i2,pref_y,pref_yi,pref_y2):
    D = len(y)
    # left side: i <= lam
    nL   = _side_sums(pref1, 1, lam)
    sum_iL = _side_sums(pref_i, 1, lam)
    sum_i2L= _side_sums(pref_i2,1, lam)
    sum_yL = _side_sums(pref_y, 1, lam)
    sum_yiL= _side_sums(pref_yi,1, lam)
    sum_y2L= _side_sums(pref_y2,1, lam)

    # right side: i > lam
    nR   = _side_sums(pref1, lam+1, D) if lam < D else 0
    sum_iR = _side_sums(pref_i, lam+1, D) if lam < D else 0
    sum_i2R= _side_sums(pref_i2,lam+1, D) if lam < D else 0
    sum_yR = _side_sums(pref_y, lam+1, D) if lam < D else 0.0
    sum_yiR= _side_sums(pref_yi,lam+1, D) if lam < D else 0.0
    sum_y2R= _side_sums(pref_y2,lam+1, D) if lam < D else 0.0

    # transformed x1=(i-lam) (left), x2=(i-lam) (right)
    sum_x1   = sum_iL - lam*nL
    sum_x1_2 = sum_i2L - 2*lam*sum_iL + (lam*lam)*nL
    sum_yx1  = sum_yiL - lam*sum_yL

    sum_x2   = sum_iR - lam*nR
    sum_x2_2 = sum_i2R - 2*lam*sum_iR + (lam*lam)*nR
    sum_yx2  = sum_yiR - lam*sum_yR

    # slopes m1,m2 (avoid divide-by-zero)
    m1 = (sum_yx1/sum_x1_2) if sum_x1_2 > 1e-18 else 0.0
    m2 = (sum_yx2/sum_x2_2) if sum_x2_2 > 1e-18 else 0.0

    # intercept at break a from total mean
    n = D
    sum_y = sum_yL + sum_yR
    a = (sum_y - sum_x1*m1 - sum_x2*m2)/n if n>0 else 0.0

    # SSE = Σ(y - (a + m1 x1 + m2 x2))^2 using side sums
    # Left:
    # ΣL y^2 - 2a ΣL y - 2m1 ΣL y x1 + nL a^2 + 2 a m1 ΣL x1 + m1^2 ΣL x1^2
    SSE_L = (sum_y2L
             - 2*a*sum_yL
             - 2*m1*sum_yx1
             + nL*(a*a)
             + 2*a*m1*sum_x1
             + (m1*m1)*sum_x1_2)
    # Right:
    SSE_R = (sum_y2R
             - 2*a*sum_yR
             - 2*m2*sum_yx2
             + nR*(a*a)
             + 2*a*m2*sum_x2
             + (m2*m2)*sum_x2_2)
    return SSE_L + SSE_R

def detect_tstar_pwl_break(y, guard):
    """Estimate t* as the knot of the best continuous two-segment line fit.

    Each integer break lam in [guard, D-guard] is scored with
    `_sse_for_lambda`, and the first lam with the smallest score is chosen.
    It is then refined to sub-tick precision by the vertex of the parabola
    through the scores at lam-1, lam and lam+1, with the offset clamped to
    +/-0.5.

    Args:
        y: sequence of curve samples, y[k] being the sample at tick k+1.
        guard: minimum distance of the break from either end of the curve.
            It is clamped to [2, D//4]; the lower bound of 2 wins when
            D//4 < 2.

    Returns:
        Estimated break tick as a float. Curves too short to hold a guarded
        break (fewer than 4 samples) return max(2.0, len(y)/2.0), the same
        fallback the slope-jump detector uses for short curves.
    """
    D = len(y)
    guard = max(2, min(guard, D//4))
    if D - guard < guard:
        # Empty search range: nothing to fit.
        return max(2.0, D/2.0)

    prefixes = _prefixes(y)

    def score(lam):
        return _sse_for_lambda(y, lam, *prefixes)

    best_lam = min(range(guard, D - guard + 1), key=score)  # first minimum wins ties

    # Both neighbours are always fittable: guard >= 2 keeps lam-1 >= 1 and
    # lam+1 <= D-1, so each segment keeps at least one tick. Using their real
    # scores (instead of substituting the centre value at the search edge)
    # avoids a spurious fixed offset when the best lam lies on the boundary.
    s_m, s_0, s_p = score(best_lam - 1), score(best_lam), score(best_lam + 1)
    curvature = s_m - 2.0*s_0 + s_p
    if curvature <= 1e-18:
        # Flat or concave: no interior minimum to interpolate.
        return float(best_lam)
    # Vertex of the parabola through (-1, s_m), (0, s_0), (+1, s_p).
    frac = 0.5*(s_m - s_p)/curvature
    return float(best_lam) + max(-0.5, min(0.5, frac))

# ---------- run one mode (OFF or ON) ----------
def run_mode(man, diag, mode_name):
    """Synthesize one depth curve for a manifest, detect its break, and derive c.

    Args:
        man: engine manifest dict. Reads domain.grid.nx / ny (default 256),
            domain.ticks (default 128), and engine_contract.schedule /
            engine_contract.chi (defaults "OFF" / 0.0).
        diag: diagnostics dict. Reads ring.outer_margin, the depth.* and
            detector.* settings, and the pre-drawn noise list under
            '_base_noise'. Zero noise is used when that key is missing or its
            length differs from depth.horizon.
        mode_name: label placed in the first column of the returned row.

    Returns:
        dict with 'row' (CSV row), 'L_surf' (ring circumference),
        'tstar' (detected break tick) and 'c' (L_surf / tstar).
    """
    domain = man.get('domain', {})
    grid = domain.get('grid', {})
    nx = int(grid.get('nx', 256))
    ny = int(grid.get('ny', 256))
    H = int(domain.get('ticks', 128))
    contract = man.get('engine_contract', {})
    schedule = contract.get('schedule', "OFF")
    chi = float(contract.get('chi', 0.0))

    # Ring: radius inset from the smaller half-extent of the grid.
    outer_margin = int(diag.get('ring', {}).get('outer_margin', 8))
    R_eff = min(nx, ny)/2.0 - outer_margin
    if R_eff <= 0:
        R_eff = max(nx, ny)/4.0
    L_surf = 2.0*math.pi*R_eff

    depth_cfg = diag.get('depth', {})
    D = int(depth_cfg.get('horizon', 4096))
    slope_post = float(depth_cfg.get('slope_post', 0.2))
    baseline = float(depth_cfg.get('baseline', 5.0))
    float(depth_cfg.get('noise_sigma', 0.0))  # parsed but not otherwise used here
    win = int(depth_cfg.get('slope_window', 64))

    detector = diag.get('detector', {"method": "slope_jump"})
    method = detector.get('method', 'slope_jump')
    guard = int(detector.get('pwl_guard', 64))

    # Ground-truth break at mid-horizon, nudged by a tiny O(chi^2) factor in ON mode.
    shift = (chi*chi)*0.1 if schedule == "ON" else 0.0
    tstar_true = 0.5*D*(1.0 + shift)

    # Both modes must receive the same noise list for a paired comparison.
    noise = diag.get('_base_noise', None)
    if noise is None or len(noise) != D:
        noise = [0.0]*D

    curve = synth_curve(D, tstar_true, slope_post, baseline, noise)

    if method != 'slope_jump':
        tstar = detect_tstar_pwl_break(curve, guard=guard)
    else:
        tstar = detect_tstar_slope_jump(curve, win=win)

    c_pred = L_surf/tstar
    return {
        "row": [mode_name, schedule, chi, nx, ny, H, D, R_eff, L_surf, tstar, c_pred],
        "L_surf": L_surf,
        "tstar": tstar,
        "c": c_pred,
    }

# ---------- main ----------
def main():
    """Run the OFF/ON surface-neutrality comparison from the command line.

    Steps:
      1. Load the two engine manifests and the diagnostics JSON.
      2. Draw one base-noise vector, seeded from the depth horizon and the
         ON manifest's chi, and store it in diag['_base_noise']. Both modes
         therefore see identical noise.
      3. Run `run_mode` for each manifest.
      4. Compare the relative difference of the predicted c values against
         tolerances.tau_c_rel.

    Files written under --out:
      metrics/surface_neutrality_modes.csv  one row per mode
      audits/surface_neutrality.json        comparison verdict
      run_info/hashes.json                  SHA-256 of each input file
    A one-line JSON summary is printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--manifest_off', required=True)
    parser.add_argument('--manifest_on', required=True)
    parser.add_argument('--diag', required=True)  # JSON with detector + depth + ring + tolerances
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    out_root = Path(args.out)
    metrics_dir = out_root / 'metrics'
    audits_dir = out_root / 'audits'
    runinfo_dir = out_root / 'run_info'
    for folder in (metrics_dir, audits_dir, runinfo_dir):
        ensure_dir(folder)

    off_path = Path(args.manifest_off)
    on_path = Path(args.manifest_on)
    diag_path = Path(args.diag)
    m_off = load_json(off_path)
    m_on = load_json(on_path)
    diag = load_json(diag_path)

    tau_c_rel = float(diag.get('tolerances', {}).get('tau_c_rel', 1e-4))

    # Shared noise draw: OFF/ON differences then come from the manifests, not the noise.
    depth_cfg = diag.get('depth', {})
    D = int(depth_cfg.get('horizon', 4096))
    sigma = float(depth_cfg.get('noise_sigma', 0.0))
    chi_on = m_on.get('engine_contract', {}).get('chi', 1e-3)
    rng_seed = int(sha256_of_text(f"A5dual|{D}|chiON={chi_on}")[:8], 16)
    diag['_base_noise'] = gen_base_noise(D, sigma, random.Random(rng_seed))  # consumed by run_mode

    off = run_mode(m_off, diag, "OFF")
    on = run_mode(m_on, diag, "ON")

    c_off = off['c']
    c_on = on['c']
    delta_rel = float('inf') if c_off == 0 else abs(c_on - c_off) / c_off
    passed = delta_rel <= tau_c_rel

    detector_cfg = diag.get('detector', {})
    audit_path = audits_dir / 'surface_neutrality.json'

    write_csv(metrics_dir / 'surface_neutrality_modes.csv',
              ['mode', 'schedule', 'chi', 'nx', 'ny', 'H', 'D', 'R_eff', 'L_surf', 'tstar_est', 'c_pred'],
              [off['row'], on['row']])

    write_json(audit_path, {
        "tau_c_rel": tau_c_rel,
        "detector": detector_cfg,
        "c_off": c_off,
        "c_on": c_on,
        "delta_c_rel": delta_rel,
        "tstar_off": off['tstar'],
        "tstar_on": on['tstar'],
        "L_surf": off['L_surf'],
        "rng_seed": rng_seed,
        "PASS": passed,
    })

    write_json(runinfo_dir / 'hashes.json', {
        "manifest_off_hash": sha256_of_file(off_path),
        "manifest_on_hash": sha256_of_file(on_path),
        "diag_hash": sha256_of_file(diag_path),
        "engine_entrypoint": f"python {Path(sys.argv[0]).name} --manifest_off <...> --manifest_on <...> --diag <...> --out <...>",
    })

    summary = {
        "detector": detector_cfg,
        "c_off": round(c_off, 9),
        "c_on": round(c_on, 9),
        "delta_c_rel": delta_rel,
        "tau_c_rel": tau_c_rel,
        "PASS": passed,
        "audit_path": audit_path.as_posix(),
    }
    print("A5 SUMMARY:", json.dumps(summary))

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()